%% MATLAB code for the modeling and analysis in the paper "Enhancing Epidemic Forecasting: Learning from COVID-19 Models"


%% Run-control switches (1 = run the stage, 0 = skip it)

% Pipeline stages
genCode   = 0;  % build per-state Vensim models, write the .cmd files and run the calibrations in Vensim
genGraphs = 0;  % draw the primary figures comparing models
genRegs   = 0;  % produce the regression-based results

% Data sourcing
needData  = 0;  % load the prediction data, rebuilt from scratch or from the saved MATLAB files (provided)
readData  = 0;  % rebuild from the raw .tab files written by Vensim (not provided) instead of loading VensimPreds.mat
genRegDt  = 0;  % refit the regressions and store the data for the primary ranking graph (otherwise load regResults.mat)

% Calibration options
testPF    = 0;  % accept a calibration only if its payoff beats the one for the previous date
reDoCheck = 0;  % reuse existing .tab outputs larger than 5000 bytes instead of rerunning them


%% Basic parameters of the analysis

% Working folder -- change this to wherever the analysis is run.
drvNm = 'D:\Dropbox (Personal)\SEIR forecast\Main model stocastic\V16-Analysis\';
cd(drvNm);

% Vensim model and the command file that is rewritten for every run
inputModel = 'SEIRB-V16';       % name of the template .mdl Vensim model
cmdNm      = 'SinglCmnd.cmd';

% Calibration/projection dates, latest first (day 0 = October 15, 2019),
% and the look-ahead horizons in days (1 to 20 weeks).
daysall = 200 + 7*(45:-1:0);
lkhrz   = 7:7:140;

% Optimization control files, whether each one also gets a no-reset run,
% and the model labels used in run names ('SEIRb-NoRst' labels the
% no-reset runs derived from the third setting).
vocList  = {'SEIRB-V12-SEIRb-NoB.voc', 'SEIRB-V12-SEIRb-NoW.voc', 'SEIRB-V12-SEIRb.voc'};
vocNoRst = [0 0 1];
vocNm    = {'SEIRb-NoB', 'SEIRb-NoW', 'SEIRb', 'SEIRb-NoRst'};
vpdList  = {'NBLLPayoff.vpd'};          % payoff definition file
truthNm  = 'incidence death projection'; % file in frcsFl holding the externally supplied forecasts
frcsFl   = 'ForecastData\';              % folder with the raw input data

% The template model is written for a placeholder state; one copy is made
% per entry of the list below, with the placeholder name replaced.
stringToChange = 'Wyoming';
replacements = { ...
    'Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado', ...
    'Connecticut', 'Delaware', 'District of Columbia', 'Florida', 'Georgia', ...
    'Guam', 'Hawaii', 'Idaho', 'Illinois', 'Indiana', 'Iowa', 'Kansas', ...
    'Kentucky', 'Louisiana', 'Maine', 'Maryland', 'Massachusetts', 'Michigan', ...
    'Minnesota', 'Mississippi', 'Missouri', 'Montana', 'Nebraska', 'Nevada', ...
    'New Hampshire', 'New Jersey', 'New Mexico', 'New York', 'North Carolina', ...
    'North Dakota', 'Ohio', 'Oklahoma', 'Oregon', 'Pennsylvania', 'Puerto Rico', ...
    'Rhode Island', 'South Carolina', 'South Dakota', 'Tennessee', 'Texas', ...
    'Utah', 'Vermont', 'Virginia', 'Washington', 'West Virginia', 'Wisconsin', ...
    'Wyoming'};

% Location and executable name of the Vensim engine
venDir = 'D:\Dropbox (MIT)\COVID-19-TestHealth\Models and Analysis\Tools\Vengine20210415\';
venNm  = 'Vensim - vengine.exe';


if genCode
    %% generate vensim models
    % One model per state, made from the template by replacing the
    % placeholder state name (genVensimModels is defined elsewhere).
    venModNm=genVensimModels(inputModel,stringToChange,replacements);

    %% create the basic structure to go into cmd generator
    % Every field starts empty; each command section below fills in only
    % the fields it needs before calling writeCMD.
    L.payoff='';
    L.sensitivity='';
    L.optparm='';
    L.savelist='';
    L.senssavelist='';
    L.data='';
    L.changes='';
    L.runname='';
    L.setvals='';
    L.runtype='';
    L.export='';
    L.exit=0;

    % vars to track performance of calibrations
    nSt=numel(replacements);
    nDy=numel(daysall);
    nVc=numel(vocList);
    pfval=-1e10*ones(nSt,nDy,nVc);   % payoff per run; -1e10 means "no .out parsed"
    params=cell(nSt,nDy,nVc);
    dayssr=zeros(nSt,nDy,nVc);
    alg=zeros(nSt,nDy,nVc);
    states=zeros(nSt,nDy,nVc);

    % A .tab output counts as complete when it exists and exceeds 5000 bytes.
    tabComplete=@(fn) isfile(fn) && getfield(dir(fn),'bytes')>5000;

    %% generate CMD in loops
    for l=1:nSt
        stNm=regexprep(replacements{l},' ','-');
        for i=1:nDy
            for j=1:nVc
                runNm=[num2str(daysall(i)) stNm '_' vocNm{j}];
                runNm2=[runNm '-NoRst'];
                hasNoRst=vocNoRst(j)==1;

                % calibration (optimization) run
                LIn=L;
                LIn.payoff={vpdList{1}};
                if i==1
                    % first entry of daysall: no earlier calibration to start
                    % from, so use the "...First.voc" control file
                    LIn.optparm={regexprep(vocList{j},'\.voc$','First.voc')};
                else
                    % start from the parameters estimated for daysall(i-1)
                    runNmP=[num2str(daysall(i-1)) stNm '_' vocNm{j}];
                    LIn.changes={[runNmP '.out']};
                    LIn.optparm={vocList{j}};
                end
                LIn.setvals={['Last Estimation Time=' num2str(daysall(i))],['Final Time=' num2str(daysall(i))]};
                LIn.runname={runNm};
                LIn.runtype={'optimize'};
                writeCMD(venModNm{l},LIn,cmdNm,'w');

                % setting the no-reset simulation
                if hasNoRst
                    LIn2=L;
                    LIn2.changes={[runNm '.out']};
                    LIn2.setvals={['Last Estimation Time=' num2str(daysall(i))],'ResetStates=0'};
                    LIn2.runname={runNm2};
                    LIn2.runtype={'run'};
                    LIn2.export={'V5saveVars.lst','*'}; %this leads to transposed values
                    writeCMD([],LIn2,cmdNm,'a');
                end

                % setting the primary simulation (last command, then exit)
                LIn2=L;
                LIn2.changes={[runNm '.out']};
                LIn2.setvals={['Last Estimation Time=' num2str(daysall(i))]};
                LIn2.runname={runNm};
                LIn2.runtype={'run'};
                LIn2.export={'V5saveVars.lst','*'}; %this leads to transposed values
                LIn2.exit=1;
                writeCMD([],LIn2,cmdNm,'a');

                % With reDoCheck on, a complete existing .tab is reused. Its
                % .out is still parsed so that pfval/params (and the testPF
                % comparison for the next date) hold real values.
                success=reDoCheck && tabComplete([runNm '.tab']);
                if success && isfile([runNm '.out'])
                    [params{l,i,j},~,pfval(l,i,j),~,~]=parsCalibPars([runNm '.out']);
                end

                % run Vensim, retrying up to 4 times until the output is complete
                % (and, with testPF, until the payoff beats the previous date's)
                trials=0;
                while ~success && trials<4
                    trials=trials+1;
                    delete([runNm '.*']);
                    runVenCmnd(cmdNm,venNm,venDir,[runNm '.tab'],[runNm '.log'],3);
                    successF=tabComplete([runNm '.tab']);
                    if isfile([runNm '.out'])
                        [params{l,i,j},~,pfval(l,i,j),~,~]=parsCalibPars([runNm '.out']);
                        if i==1 || ~testPF
                            success=successF;
                        else
                            success=successF && pfval(l,i,j)>pfval(l,i-1,j);
                        end
                    end
                end

                % remove intermediate files
                delete([runNm '_endpoint.*']);
                delete([runNm '_startpoint.*']);
                delete([runNm '.log']);
                delete([runNm '*.rep']);
                delete([runNm '*.err']);
                delete([runNm '*.vdf']);
                if hasNoRst
                    % only touch no-reset files when this setting made a no-reset run
                    delete([runNm2 '.rep']);
                    delete([runNm2 '*.err']);
                    delete([runNm2 '*.vdf']);
                end
                dayssr(l,i,j)=i;
                alg(l,i,j)=j;
                states(l,i,j)=l;
            end
        end
    end
    fclose('all');
    tcalib=table(reshape(pfval,[],1),reshape(dayssr,[],1),reshape(alg,[],1),reshape(states,[],1));
    tcalib.Properties.VariableNames={'payoff','days','alg','state'};
    save('calibOutcomes','tcalib','params');
end





%% read data from vensim and other predictions
% Builds (readData=1) or loads (readData=0) the data used by later stages:
%   Tt    - per calibration date, model variant and state (plus a US
%           aggregate row with fips 0), the time, projected and reported
%           death series read from the Vensim .tab outputs
%   prdDt - weekly forecast records: external forecasts plus the SEIRb
%           variants, with truth, errors, population and naive values
%   mtrc  - comparison metrics per model, week, location and horizon
if needData
    if readData
        FIPS=readtable([frcsFl 'Fips.xlsx']);
        Tt=[];
        varNums=5;
        opts = delimitedTextImportOptions('NumVariables',varNums,'Delimiter','\t','VariableNamesLine',1,'DataLines',2);
        opts.VariableTypes(1:end)={'double'};
        TtVars={'t','Preds','trth','fips','PredictDay','Model','state'};
        usIdx=numel(replacements)+1; % state index of the US aggregate rows

        for i=1:numel(daysall)
            for j=1:numel(vocNm)
                usVls=[];
                usTrh=[];
                for l=1:numel(replacements)
                    runNm=[num2str(daysall(i)) regexprep(replacements{l},' ','-') '_' vocNm{j}];
                    if isfile([runNm '.tab'])
                        DTA=readtable([runNm '.tab'],opts);
                        varNms=DTA.Properties.VariableNames;
                        stTag=regexprep(replacements{l},' ','');
                        inx1=contains(varNms,['OutputsOverTime_' stTag '_Death'],'IgnoreCase',true);
                        inx2=contains(varNms,['DataFlowExport_' stTag '_Death'],'IgnoreCase',true);
                        preds=DTA{:,inx1}';
                        tms=DTA.Time(:)';
                        trth=DTA{:,inx2}';
                        inx=~isnan(tms);
                        fip=FIPS.FIPS(matches(FIPS.Name,replacements{l}));
                        tbl=table({tms(inx)},{preds(inx)},{trth(inx)},fip,daysall(i),{vocNm{j}},l,'VariableNames',TtVars);
                        if isempty(Tt)
                            Tt=tbl;
                        else
                            Tt=[Tt;tbl];
                        end
                        usVls=[usVls;preds(inx)];
                        usTrh=[usTrh;trth(inx)];
                    end
                end
                % US aggregate: sum over the states read above. Only added when
                % at least one state output exists, otherwise tms/inx would be
                % stale or undefined.
                if ~isempty(usVls)
                    tbl=table({tms(inx)},{sum(usVls,1)},{sum(usTrh,1)},0,daysall(i),{vocNm{j}},usIdx,'VariableNames',TtVars);
                    Tt=[Tt;tbl];
                end
            end
        end

        % external forecasts: horizon in weeks and week-ending day in model
        % time (day 0 = 2019-10-15)
        prdDt=readtable([frcsFl truthNm]);
        prdDt.hrzn=floor(-days(datetime(prdDt.forecast_date)-datetime(prdDt.target_end_date))/7+1);
        prdDt.WeekEnding=days(datetime(prdDt.target_end_date)-datetime('2019-10-15'));
        prdDt.Truth=nan(size(prdDt,1),1);
        prdDt(:,1:4)=[]; % drop the first four columns of the raw file
        outDZ=size(prdDt,1); % number of externally supplied forecast rows

        save('VensimPreds','Tt','prdDt','FIPS');

        %% incorporate new predictions with the old ones
        % Append blank rows, one slot per (date, model, location, horizon).
        % Slots left unused keep an empty location and are removed below.
        prdDtN=table('Size',[numel(daysall)*numel(vocNm)*usIdx*numel(lkhrz) 6],'VariableTypes',{'double','cellstr','cellstr','double','double','double'});
        prdDtN.Properties.VariableNames=prdDt.Properties.VariableNames;
        prdDt=[prdDt;prdDtN];

        % For each location pick a reference series: the first record whose
        % truth vector has more than 480 points, with that SAME record's
        % time vector (so indices into both line up).
        dya=cell(1,usIdx);
        dta=cell(1,usIdx);
        for l=1:usIdx
            inxDy=find(Tt.state==l);
            for cnt=1:numel(inxDy)
                if numel(Tt.trth{inxDy(cnt)})>480
                    dya{l}=Tt.trth{inxDy(cnt)};
                    dta{l}=Tt.t{inxDy(cnt)};
                    break
                end
            end
            if isempty(dya{l})
                warning('No record with more than 480 truth points for location index %d; its predictions are skipped.',l);
            end
        end

        inx=outDZ; %pointer: new model rows are written after the external ones
        for i=1:numel(daysall)
            inxDV=daysall(i)+1; % the first day of projection in numerical sense
            for j=1:numel(vocNm)
                for l=1:usIdx
                    dt=dta{l};
                    dy=dya{l};
                    inxD=find(dt==inxDV); %the first day of projection, as an index into the output vectors

                    tinx=Tt.state==l & matches(Tt.Model,vocNm{j}) & Tt.PredictDay==daysall(i);
                    if sum(tinx)==1
                        if Tt.fips(tinx)==0
                            locNm='US';
                        else
                            locNm=num2str(Tt.fips(tinx),'%0.2d');
                        end
                        y=Tt.Preds{tinx};
                        if numel(y)==numel(dy)
                            % populate prediction table with weekly sums for each horizon
                            for m=1:numel(lkhrz)
                                if inxD-1+m*7<=numel(y)
                                    wk=inxD+(m-1)*7:inxD-1+m*7;
                                    inx=inx+1;
                                    prdDt.model(inx)={vocNm{j}};
                                    prdDt.value(inx)=sum(y(wk));
                                    prdDt.hrzn(inx)=m;
                                    prdDt.WeekEnding(inx)=inxDV-1+m*7;
                                    prdDt.Truth(inx)=sum(dy(wk));
                                    prdDt.location(inx)={locNm};
                                end
                            end
                        end
                        % populate the truth values for all predictions (any model)
                        % at this location, horizon and week
                        for m=1:numel(lkhrz)
                            if inxD-1+m*7<=numel(dy)
                                inxF=matches(prdDt.location,locNm) & prdDt.hrzn==m & prdDt.WeekEnding==inxDV-1+m*7;
                                prdDt.Truth(inxF)=sum(dy(inxD+(m-1)*7:inxD-1+m*7));
                            end
                        end
                    end
                end
            end
        end

        prdDt(matches(prdDt.location,{''}),:)=[]; % drop unused preallocated rows
        % augment FIPS table with populations and prdDt with normalized errors
        prdDt.Error=abs(prdDt.value-prdDt.Truth);
        prdDt.WeekStart=prdDt.WeekEnding-prdDt.hrzn*7;

        Pops=readtable([frcsFl 'StatePopulations.xlsx']);
        for i=1:size(FIPS,1)
            FIPS.location(i)={num2str(FIPS.FIPS(i),'%0.2d')};
            FIPS.Pop(i)=Pops.Population(matches(Pops.Location,FIPS.Name(i)));
        end
        FIPS.Name(end+1)={'USA'};
        FIPS.location(end)={'US'};
        FIPS.Pop(end)=Pops.Population(matches(Pops.Location,'USA'));
        locs=unique(prdDt.location);
        for i=1:numel(locs)
            inx=matches(prdDt.location,locs{i});
            popVal=FIPS.Pop(matches(FIPS.location,locs{i}))/1e6;
            if ~isempty(popVal)
                prdDt.PopMil(inx)=popVal; % population in millions
            end
            % mean horizon-1 truth of the first model variant at this location
            truAve=mean(prdDt.Truth(matches(prdDt.model,vocNm{1}) & prdDt.hrzn==1 & matches(prdDt.location,locs{i})));
            prdDt.AveDeath(inx)=truAve;
        end

        % naive (constant) predictor: weekly reported deaths from deathData,
        % matched to each record's WeekStart
        dData=readtable([frcsFl 'deathData.xlsx'],'VariableNamingRule','preserve');
        inxdd=(103:103+7*62-1)-99; % 62 full weeks of rows
        prdT=reshape(dData.Time(inxdd),7,[]);
        prdTt=prdT(end,:)'; % time of the last day of each week
        wkNum=numel(inxdd)/7;
        nvPrd=table('Size',[(size(dData,2)-1)*wkNum 3],'VariableTypes',{'double','double','cellstr'});
        nvPrd.Properties.VariableNames={'Time','value','location'};
        pinx=1;
        for i=2:size(dData,2)
            nvPrd.value(pinx:pinx+wkNum-1)=sum(reshape(dData{inxdd,i},7,[]),1)';
            nvPrd.Time(pinx:pinx+wkNum-1)=prdTt;
            nvPrd.location(pinx:pinx+wkNum-1)=FIPS.location(matches(FIPS.Name,dData.Properties.VariableNames{i}));
            pinx=pinx+wkNum;
        end
        nvPrd(matches(nvPrd.location,{''}),:)=[];

        strtWk=prdDt.WeekStart;
        for i=1:numel(locs)
            inx=matches(prdDt.location,locs{i});  %all records for a given location
            for j=1:wkNum
                nvVal=nvPrd.value(nvPrd.Time==prdTt(j) & matches(nvPrd.location,locs{i}));
                if ~isempty(nvVal)
                    prdDt.Naive(inx & strtWk==prdTt(j))=nvVal;
                end
            end
        end

        % create the summary stats
        hrz=unique(prdDt.hrzn);
        mods=unique(prdDt.model);
        hrzall=prdDt.hrzn(:);
        locsAll=prdDt.location(:);
        NerrAll=prdDt.Error(:)./prdDt.PopMil(:); %normalized (per million) error; the non-normalized one could be used instead
        errAll=prdDt.Error(:);
        modAll=prdDt.model(:);
        valAll=prdDt.value(:);
        trthAll=prdDt.Truth(:);
        NVerrAll=(prdDt.Error(:)-abs(prdDt.Naive(:)-prdDt.Truth(:)))./prdDt.PopMil;
        MNerrAll=prdDt.Error(:)./prdDt.AveDeath(:)*100;
        wksall=strtWk(:); %starting weeks for projections
        wks=unique(strtWk);

        mtNum=10; %number of metrics stored
        % last row along dim 1 holds the median over models (metric 8 only)
        mtrc=nan(numel(mods)+1,numel(wks),numel(locs),numel(hrz),mtNum);

        for i=1:numel(wks)
            for l=1:numel(locs)
                inxL=wksall==wks(i) & matches(locsAll,locs{l});
                for j=1:numel(hrz)
                    inx=inxL & hrzall==hrz(j);
                    if sum(inx)>1
                        Nerrs=NerrAll(inx);
                        errs=errAll(inx);
                        merrs=median(Nerrs);
                        mdl=modAll(inx);
                        vals=valAll(inx);
                        trts=trthAll(inx);
                        [~,sinx]=sort(Nerrs,'descend');
                        NVerrs=NVerrAll(inx);
                        MNerrs=MNerrAll(inx);

                        for k=1:numel(Nerrs)
                            minx=find(matches(mods,mdl{k})); %the index of the model in the full vector of models
                            mtrc(minx,i,l,j,1)=numel(Nerrs)-1;   %the number of competitions for each model
                            mtrc(minx,i,l,j,2)=find(sinx==k)-1; %the number of wins (models with a larger error)
                            mtrc(minx,i,l,j,3)=Nerrs(k)-merrs; %Excess error compared to the median of the comparison set
                            mtrc(minx,i,l,j,5)=vals(k); %Value for prediction
                            mtrc(minx,i,l,j,6)=trts(k); %Value for truth
                            mtrc(minx,i,l,j,7)=errs(k); %absolute error
                            mtrc(minx,i,l,j,8)=Nerrs(k); %normalized absolute error
                            mtrc(minx,i,l,j,9)=NVerrs(k); %normalized by population and naive predictor
                            mtrc(minx,i,l,j,10)=MNerrs(k);  %error normalized by location-specific average of true deaths
                        end
                        mtrc(end,i,l,j,8)=merrs; %median error for each setting, stored after the models
                    end
                end
            end
        end

        mtrc(:,:,:,:,4)=mtrc(:,:,:,:,2)./mtrc(:,:,:,:,1); %calculating win fraction for models

        popLocs=zeros(numel(locs),1);
        for i=1:numel(locs)
            popLocs(i)=unique(prdDt.PopMil(matches(prdDt.location,locs{i})));
        end

        % Read model categories
        modNm=readtable([frcsFl 'ModelNames.xlsx']);

        % Sanity check: report SEIRb records with per-capita error above 100
        inxLst=matches(prdDt.model,'SEIRb') & prdDt.Error./prdDt.PopMil>100;
        sum(inxLst)
        GMod=groupcounts(prdDt(inxLst,:),{'model','location','WeekStart'});
        GMod(GMod.GroupCount>5,:)
        GLoc=groupcounts(prdDt(inxLst,:),'location')
        mean(prdDt.Error(inxLst)./prdDt.PopMil(inxLst),'omitnan')
        median(prdDt.Error(inxLst)./prdDt.PopMil(inxLst),'omitnan')

        prdDt.logHrz=log(prdDt.hrzn);
        prdDt.Week=categorical(prdDt.WeekEnding);
        prdDt.State=categorical(prdDt.location);

        save('VensimPreds','Tt','prdDt','FIPS','mtrc','Pops','mods','modNm','popLocs','dya');
        writetable(prdDt,'PredictionsTable.xlsx');

    else
        load('VensimPreds');
    end
end






if genRegs
    %% Regression-based comparison of model errors
    % For each horizon, regress the log per-capita error on model plus
    % location-by-week effects and keep the model coefficients (betaAll).
    % Then plot the implied error multipliers and the resulting ranking.
    vocExc={vocNm{[1,2,4]},'COVIDhub-ensemble'}; %list of models to exclude from ranking exercise
    if genRegDt
        inclTr=50; %a model needs more than this many records at a horizon to enter that regression
        clear betanames;
        for i=1:numel(lkhrz)
            prdSmp=prdDt(prdDt.hrzn==i,:);

            G=groupcounts(prdSmp,'model');
            G(ismember(G.model,vocExc),:)=[]; %remove the models we don't need in ranking
            inxM=G.GroupCount>inclTr;
            reginc=ismember(prdSmp.model,G.model(inxM));

            % the 0.01 offset keeps zero errors finite on the log scale
            prdSmp.LogNErr=log((0.01+prdSmp.Error)./prdSmp.PopMil);

            lme=fitlm(prdSmp(reginc,:),'LogNErr~1+model+location*Week');
            betanames{i}=table(lme.Coefficients.Properties.RowNames,'VariableNames',{'Name'});
            betanames{i}.vals=lme.Coefficients.Estimate;
            % drop the week, location and interaction terms, keeping intercept and model terms
            betanames{i}(contains(betanames{i}.Name,'Week_'),:)=[];
            betanames{i}(contains(betanames{i}.Name,'location_'),:)=[];
            if i==1
                betaAll=betanames{i};
                betaAll.Properties.VariableNames{2}='H1';
            else
                % align by coefficient name with the horizon-1 list
                betaAll.New=NaN(size(betaAll,1),1);
                for j=1:size(betanames{i},1)
                    betaAll.New(matches(betaAll.Name,betanames{i}.Name(j)))=betanames{i}.vals(j);
                end
                betaAll.Properties.VariableNames{i+1}=['H' num2str(i)];
            end
        end
        save('regResults','betaAll','betanames');
    else
        load('regResults');
    end

    figSmp=2:size(betaAll,1); % rows after the first (intercept) row
    modNms=erase(betaAll.Name(figSmp),'model_');

    % line-style/marker combinations and color classes
    mrkers=repmat(cellstr(['o+*s<x>p^dvh']'),[1,4]);
    lst=repmat({'-','--',':','-.'},[size(mrkers,1),1]);
    mrkSet=join([reshape(lst,[],1),reshape(mrkers,[],1)],'');
    clrs={[1,0,0],[0,0.4,0.8],[0,0,0],[0,0.7,0],[0.9,0.7,0]};
    clNum=numel(clrs); %number of color codes to use
    clrCds=clNum*ones(numel(figSmp),1);  %color class of each graphed model
    for i=1:numel(clrCds)
        inM=matches(modNm.Model,modNms{i});
        if any(inM)
            clrCds(i)=modNm.TypeNum(inM);
        else
            clrCds(i)=3;
        end
    end

    X=1:numel(lkhrz);

    % Figure 1: error multiplier relative to the 10th-percentile coefficient
    g=figure;
    hold on
    cInx=[0 24 0 23 10]; %starting offsets into mrkSet for each color class
    minErrs=prctile(betaAll{2:end,2:end},10,1);
    for i=figSmp
        c=clrCds(i-1);
        cInx(c)=cInx(c)+1;
        lTk=1;
        if matches(modNms(i-1),'SEIRb')
            lTk=2;
        end
        plot(X,exp(betaAll{i,2:end}-minErrs),mrkSet{cInx(c)},'Color',clrs{c},'LineWidth',lTk);
    end

    legend(replace(modNms,'_','-'),'location','southoutside','NumColumns',4)
    ylim([0.7 3]);
    title('Estimated Error for Different Models','FontSize',15);
    xlabel('Projection Horizon');
    ylabel('Error Multiplier Compared to 90 Percentile Model');
    set(g, 'PaperUnits', 'inches');
    set(g, 'PaperPosition', [0 0 11 11]);
    print(g,'g-ModelErrors','-djpeg','-r300');

    % Figure 2: rank of each model at each horizon (1 = smallest coefficient)
    g=figure;
    hold on
    [xx,yy]=sort(betaAll{2:end,2:end},1); % yy(r,h): model holding rank r at horizon h
    yy(isnan(xx))=NaN;                      % models without an estimate hold no rank
    cInx=[0 24 0 23 10];

    legInx=false(size(figSmp));
    for i=figSmp
        c=clrCds(i-1);
        cInx(c)=cInx(c)+1;
        % ranks (rows) and horizons (columns) where this model appears; using
        % the column index keeps ranks on the right horizon even when the
        % model is missing at some horizons
        [aa,bb]=ind2sub(size(yy),find(yy==i-1));
        lTk=1;
        if matches(modNms(i-1),'SEIRb')
            lTk=2;
        end
        if ~matches(modNms{i-1},vocExc{end})
            plot(X(bb),aa,mrkSet{cInx(c)},'Color',clrs{c},'LineWidth',lTk);
            legInx(i-1)=true;
        end
    end

    legNms=replace(modNms,'_','-');
    legend(legNms{legInx},'location','southeast','NumColumns',3);

    yl=ylim;
    ylim([0 yl(2)]);
    title('Ranking of Different Models','FontSize',15);
    xlabel('Projection Horizon');
    ylabel('Projection Quality Rank');
    set(gca, 'YDir','reverse')
    set(g, 'PaperUnits', 'inches');
    set(g, 'PaperPosition', [0 0 12 9]);
    print(g,'g-Ranking','-djpeg','-r300');
end

if genGraphs
    %% Create overall graphs summarizing key findings

    % set line and color codes
    mrkers=repmat(cellstr(['o+*s<x>p^dvh']'),[1,4]);
    lst=repmat({'-','--',':','-.'},[size(mrkers,1),1]);
    mrkSet=join([reshape(lst,[],1),reshape(mrkers,[],1)],''); %line-style/marker combinations
    mrkSet2={'->','-p','-^','-d','-'}; %styles for the component-impact graph, in plotting order
    clrs={[1,0,0],[0,0.4,0.8],[0,0,0],[0,0.7,0],[0.9,0.7,0]};
    clNum=numel(clrs); %number of color codes to use
    clrCds=clNum*ones(numel(mods),1);  %color class of each model
    for i=1:numel(clrCds)
        inM=matches(modNm.Model,mods{i});
        if any(inM)
            clrCds(i)=modNm.TypeNum(inM);
        else
            clrCds(i)=3;
        end
    end
    clrCds(end+1)=2; %color class of the median row stored after the models in mtrc

    % create the graphs
    % NOTE(review): gTrs is applied to the summed metric 8 (normalized
    % absolute error), not to the comparison count in metric 1 -- confirm intent
    gTrs=100;
    popTrs=1; %threshold (millions) to include given location population
    locInc=popLocs>popTrs;
    gMtr=[3 4 7 9 8 10];
    gMean=[0 1 0 0 0 0];
    ymx=[5 0.75 250 8 25 100];
    ymn=[-5 0.2 0 -8 0 0];
    difSmp=[0 0 0 0 1 0];
    gTits={'Normalized Error Relative to Median Model','Head-to-Head Win Fraction','Absolute Prediction Error','Normalized Error Relative to Constant Model','Impact of Model Components on Prediction Error','Error Normalized by Average Death in Each Location'};
    gTits2={'Median Normalized Error Minus Median Model (Death/Million/Week)','Win Fraction','Absolute Error (Death/Day)',...
        'Median Normalized Error Minus Constant Model (Death/Million/Week)','Median per Capita Error (Death/Million/Week)','Error Normalized by Mean (%)'};
    sNm={'NrmMed','WinFrac','AbsErr','NrmNavi','NrmAbs','MNrmFrac'};
    hrz=unique(prdDt.hrzn);
    xAx=1:numel(hrz);
    Xin=1:1/numel(mods):numel(lkhrz); %fine grid for interpolated curves
    firstGraph=5; %only metrics from this position of gMtr onward are drawn
    for k=firstGraph:numel(gMtr)

        mtInx=gMtr(k);  %metric index to be graphed
        g=figure;
        hold on
        cInx=[0 24 0 23 10]; %starting offsets into mrkSet for each color class
        nCurves=numel(mods)+difSmp(k); %models, plus the median row when difSmp(k)
        modInx=false(nCurves,1);
        zcnt=0;
        for i=1:nCurves
            xInx=squeeze(sum(squeeze(mtrc(i,:,locInc,:,8)),[1 2],'omitnan'))>gTrs;
            lTk=1;
            if difSmp(k)
                if i==numel(mods)+1
                    incSmp=true;
                else
                    incSmp=contains(mods(i),'SEIRb');
                end
                lTk=2;
            else
                incSmp=true;
            end
            if sum(xInx)>0 && incSmp
                zcnt=zcnt+1;
                for j=1:clNum
                    cInx(j)=cInx(j)+sum(clrCds(i)==j);
                end
                if gMean(k)
                    yOut=squeeze(mean(squeeze(mtrc(i,:,locInc,:,mtInx)),[1 2],'omitnan'));
                else
                    yOut=squeeze(median(squeeze(mtrc(i,:,locInc,:,mtInx)),[1 2],'omitnan'));
                end
                if sum(xInx)>1
                    if difSmp(k)
                        plot(Xin,interp1(xAx(xInx),yOut(xInx),Xin),mrkSet2{zcnt},'Color',clrs{clrCds(i)},'MarkerIndices',(i-1)*2+1:numel(mods)*2:numel(Xin),'LineWidth',lTk);
                    else
                        plot(Xin,interp1(xAx(xInx),yOut(xInx),Xin),mrkSet{cInx(clrCds(i))},'Color',clrs{clrCds(i)},'MarkerIndices',(i-1)*2+1:numel(mods)*2:numel(Xin),'LineWidth',lTk);
                    end
                else
                    plot(xAx(xInx),yOut(xInx),mrkSet{cInx(clrCds(i))},'Color',clrs{clrCds(i)});
                end
                modInx(i)=true;
            end
        end
        if difSmp(k)
            % legend entries in plotting order; the median entry only if it was drawn
            legNames=mods(modInx(1:numel(mods)));
            if modInx(end)
                legNames{end+1}='Median Ensemble';
            end
            legend(legNames,'location','southeast','NumColumns',3);
        else
            legend(mods(modInx),'location','southoutside','NumColumns',4);
        end
        ylim([ymn(k) ymx(k)]);
        title(gTits{k},'FontSize',15);
        xlabel('Projection Horizon (Weeks)');
        ylabel(gTits2{k});
        set(g, 'PaperUnits', 'inches');
        if difSmp(k)
            set(g, 'PaperPosition', [0 0 5.5 5]);
        else
            set(g, 'PaperPosition', [0 0 11 11]);
        end
        print(g,['g1-' sNm{k}],'-djpeg','-r300');
    end

    %% quick US comparisons for a few models
    % weekly projections of selected models overlaid on the 7-day moving
    % sum of US reported deaths, for a few evenly spaced start weeks
    mdlSet={'SEIRb','IHME-CurveFit'};
    flg=zeros(1,numel(mdlSet));
    vocClr={':b','-.r','--g','-y'};
    locNm='US'; %this should be set for the location to be plotted
    weekNum=5;
    weeks=unique(prdDt.WeekStart);
    weekSet=weeks(1:round(numel(weeks)/weekNum):numel(weeks));
    locInx=matches(prdDt.location,locNm);
    trth=movsum(dya{end},[3 3]);
    xax=Tt.t{1};
    figure
    plot(xax,trth,'k','LineWidth',3);
    hold on
    hleglines=[];
    legs={};
    for i=1:numel(weekSet)
        week=weekSet(i);
        weekend=week+7*(1:numel(lkhrz));
        for j=1:numel(mdlSet)
            rcrds=matches(prdDt.model,mdlSet{j}) & prdDt.WeekStart==week & locInx;
            [~,yy]=ismember(prdDt.WeekEnding(rcrds),weekend);
            [xx,zz]=sort(yy);
            if sum(rcrds)>0
                yax=prdDt.value(rcrds);
                plt=plot(week+7*yy(zz(xx>0)),yax(zz(xx>0)),vocClr{j},'LineWidth',1.5);
                if ~flg(j)
                    hleglines(end+1)=plt(1);
                    legs{end+1}=mdlSet{j};
                    flg(j)=1;
                end
            end
        end
    end
    legend(hleglines,legs,'location','northwest')

    % Graph of our model's performance for a few locations
    mdlSet={'SEIRb'};
    vocClr={':b','-.r','--g','-y'};
    locNm={'USA'};%,'California','Texas','New York'}; %locations to be plotted (FIPS.Name values)
    weekSet=[200,298,396,445];
    xax1=Tt.t{1}+datetime(2019,10,15);

    for k=1:numel(locNm)
        fipNm=FIPS.FIPS(matches(FIPS.Name,locNm{k}));
        g=figure;
        for i=1:numel(weekSet)
            week=weekSet(i);
            for j=1:numel(mdlSet)
                rcrds=matches(Tt.Model,mdlSet{j}) & Tt.PredictDay==week & Tt.fips==fipNm;
                if i==1 && j==1
                    % reported deaths (7-day moving mean) as the reference line
                    trth=movmean(Tt.trth{rcrds},[3 3]);
                    plot(xax1,trth,'k','LineWidth',3);
                    hold on
                end
                if sum(rcrds)==1
                    yax=Tt.Preds{rcrds};
                    xax=Tt.t{rcrds};
                    plot(xax(xax>week)+datetime(2019,10,15),yax(xax>week),vocClr{j},'LineWidth',1.5);
                end
            end
        end
        title([strjoin(mdlSet,', ') ' Projections for ' locNm{k} ' Daily Deaths']);
        ylabel('Daily Death Incidents');
        set(g, 'PaperUnits', 'inches');
        set(g, 'PaperPosition', [0 0 5.5 5]);
        print(g,['g1-' regexprep(locNm{k},' ','') 'Deaths'],'-djpeg','-r300');
    end
end
    
    
    
    
  